﻿<!DOCTYPE html>
<html lang="en">

<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Neural Network Demo</title>

    <!-- Loads index.mjs as an ES module; the inline module script in <body> imports `torch` from the same file. -->
    <script type="module" src="index.mjs"></script>

    <!-- NOTE(review): "script" is not a standard <link rel> type, so this line presumably does not load demo.js.
         Confirm whether it can be dropped; the page logic lives in the inline script at the end of <body>. -->
    <link rel="script" href="demo.js">
    <link rel="stylesheet" href="demo.css">
</head>

<body>
    <script type="module">
        // Publish the `torch` export on `window` so the classic (non-module) script
        // further down the page can reach it through the global name `torch`.
        import { torch } from './index.mjs';
        window.torch = torch;
    </script>

    <div style="height: 20px;"></div>
    <div class="container">
        <!-- NOTE(review): both blocks below share id="right-separator". Ids should be unique, but the value is
             kept because demo.css may select on it; confirm before renaming. -->
        <div class="separator" id="right-separator" style="display: block; padding-bottom: 0px;">
            <img style="display: block;" class="icon" src="../../assets/icon_gray_bg.png" alt="JS-Torch icon">
            <a target="_blank" rel="noopener" href="https://github.com/eduardoleao052/js-pytorch"> <img class="logos" src="../../assets/github.png" alt="GitHub repository" style="width: 25px; height: 25px; display: inline-block; margin-bottom: 5px;"></a>
            <a target="_blank" rel="noopener" href="https://www.linkedin.com/in/eduardo-leao-368bb6259/"> <img class="logos" src="../../assets/linkedin.png" alt="LinkedIn profile" style="width: 25px; height: 25px; display: inline-block; margin-bottom: 5px;"></a>
        </div>

        <div class="separator" id="right-separator" style="margin-top: 10px; display: block;">
            <p style="display: block; font-size: 16px; color: #292929; margin-bottom: 25px;">Welcome to <b>JS-Torch's Web Demo</b>! You can choose the <b>Model Hyperparameters</b> on the left, set the <b>Model Layers</b> on the right (number of layers and hidden dimension on each). </p>
            <ul> 
                <li>This model was trained on a <b>Dummy dataset</b>, composed of <b>randomly generated images</b>. </li>
                <br>
                <li>The batch size is <b>8</b>, and the images have a dimension of <b>32x32</b>. </li>
                <br>
                <li>The output is a number from 0 to 31.</li>
                <br>
                <li>The Model is trained using an <b>Adam Optimizer</b> and a <b>Cross Entropy Loss</b>.</li>
            </ul>
        </div>
    </div>

    <div class="container" id="upper-container">
        <div class="left-div">
            <h2 style="display: inline-block; margin-left: 20px; margin-top: 2px; margin-bottom: 10px;">Training Parameters</h2>
            <div class="separator" id="left-separator">

                <!-- max matches the upper bound that checkInput() enforces on change. -->
                <label for="batch-size">Batch Size:</label>
                <input type="number" id="batch-size" name="batch-size" min="1" value="8" max="32" onchange="checkInput(this)">

                <label for="learning-rate">Learning Rate:</label>
                <input type="number" id="learning-rate" name="learning-rate" step="0.001" min="0" value="0.001" required>

                <label for="regularization">L2 Regularization:</label>
                <input type="number" id="regularization" name="regularization" step="0.001" min="0" value="0.001">

                <label for="beta1">Beta 1 (Adam):</label>
                <input type="number" id="beta1" name="beta1" step="0.001" min="0" value="0.99">

                <label for="beta2">Beta 2 (Adam):</label>
                <input type="number" id="beta2" name="beta2" step="0.0001" min="0" value="0.999">

                <label for="epsilon">Epsilon (Adam):</label>
                <input type="number" id="epsilon" name="epsilon" step="0.000001" min="0" value="0.000001">

            </div>
        </div>
        <div class="right-div">
            <div>
                <h2 style="display: inline-block; margin-right: 15px; margin-left: 20px;">Model Layers</h2>
                <!-- The glyph-only buttons get an accessible name for assistive technology. -->
                <button type="button" class="layer-button" aria-label="Add layer" onclick="addBox()">+</button>
                <button type="button" class="layer-button" aria-label="Remove layer" onclick="removeBox()">-</button>
                <div class="separator" id="layersBox">
                    
                </div>
            </div>
            <div id="training-buttons">
                <button type="button" id="start-button" onclick="trainLoopInitializer()">Train</button>
                <button type="button" id="stop-button" onclick="pauseTraining()">Pause</button>
                <button type="button" id="reset-button" onclick="resetTraining()">Reset</button>
            </div>
        </div>
    </div>

    <div class="container" style="padding-bottom: 0px;">
        <h2>Graph</h2>
        <!-- This is where the graph will be displayed -->
        <canvas id="graph" width="700" height="350"></canvas>
        <div style="display: inline-block; width: 90%; margin-bottom: 15px;">
            <p id="iter"> <b>Iteration:</b> </p>
            <p id="total-visited"> <b>Total Training Examples:</b> </p>
            <p id="loss"> <b>Loss:</b> </p>
        </div>
        <!-- Was "<divs tyle=...>": an unknown element with a meaningless attribute, which also left the
             closing tags below unbalanced. -->
        <div style="display: inline-block;">
            <a target="_blank" rel="noopener" href="https://github.com/eduardoleao052/js-pytorch"><img class="logos" src="../../assets/github.png" alt="GitHub repository" style="width: 25px; height: 25px; display: inline-block; margin-bottom: 0px;"></a>
            <a target="_blank" rel="noopener" href="https://www.linkedin.com/in/eduardo-leao-368bb6259/"><img class="logos" src="../../assets/linkedin.png" alt="LinkedIn profile" style="width: 25px; height: 25px; display: inline-block; margin-bottom: 0px;"></a>
        </div>
    </div>

    <script>
        // Hidden-layer widths, one entry per box shown in #layersBox (kept in sync by addBox/removeBox/changeDim).
        let boxCount = [];
        // Values plotted by plotGraph(), appended once per training step.
        let data = [];
        // True once trainStep() has built the model; while true, addBox/removeBox refuse to change the layout.
        let training = false;
        // trainLoop() keeps rescheduling itself only while this is true; pause/reset set it to false.
        let in_loop = true;
        // Stand-in value plotted when the loss is NaN; multiplied by 1.5 on every NaN step.
        let overFlow = 1;
        // Number of training steps run since the last reset.
        let iter = 0;
        // Running sum of the batch sizes used since the last reset.
        let total_visited = 0;

        /**
         * Appends one hidden-layer box (a div wrapping a numeric input) to #layersBox
         * and records its width in `boxCount`.
         *
         * Does nothing while training is active or when five boxes already exist.
         * A new box starts with the width of the current last box, or 10 when that
         * value is missing or falsy.
         */
        function addBox() {
            if (training || boxCount.length >= 5) {
                return;
            }

            const layerIndex = boxCount.length;
            const layersBox = document.getElementById('layersBox');

            // Wrapper element for the layer.
            const box = document.createElement('div');
            box.id = `box-div-${layerIndex}`;
            box.classList.add('box');

            // Numeric field holding the number of neurons of this layer.
            const neuronInput = document.createElement('input');
            neuronInput.type = 'number';
            neuronInput.value = boxCount[layerIndex - 1] || '10';
            neuronInput.max = '100';
            // Custom property: position of this layer inside boxCount, read back by changeDim().
            neuronInput.idx = layerIndex;
            neuronInput.id = `box-input`;
            neuronInput.onchange = function () {
                changeDim(this, boxCount);
            };

            box.appendChild(neuronInput);
            layersBox.appendChild(box);
            boxCount.push(Number(neuronInput.value));
        }

        /**
         * Removes the last hidden-layer box from #layersBox and drops its entry from `boxCount`.
         *
         * Does nothing while training is active or when only one box is left.
         */
        function removeBox() {
            if (training || boxCount.length <= 1) {
                return;
            }
            const layersBox = document.getElementById('layersBox');
            layersBox.removeChild(layersBox.lastChild);
            boxCount.pop();
        }

        /**
         * Validates the width typed into a layer box and stores it in `boxCount`.
         *
         * The value is rounded to an integer and clamped to [1, 100], so a layer can
         * never end up with a zero, negative or fractional width. An empty or
         * non-numeric field falls back to the width previously stored for that layer.
         * The sanitised value is written back to the field.
         *
         * @param {{value: (string|number), idx: number}} el - The input element; `idx`
         *     is the layer's position in `boxCount`.
         * @param {number[]} boxCount - Array of layer widths, updated in place.
         */
        function changeDim(el, boxCount) {
            const MIN_DIM = 1;
            const MAX_DIM = 100;

            let dim = Math.round(Number(el.value));
            if (el.value === '' || !Number.isFinite(dim)) {
                // Nothing usable was typed: keep whatever the layer had before.
                const previous = boxCount[el.idx];
                dim = Number.isFinite(previous) ? previous : MIN_DIM;
            }
            dim = Math.min(MAX_DIM, Math.max(MIN_DIM, dim));

            el.value = dim;
            boxCount[el.idx] = dim;
        }

        /**
         * Keeps the batch-size field inside the supported range.
         *
         * The value is rounded to an integer and clamped to [1, 32]. An empty or
         * non-numeric field becomes 1, so a training step never receives a batch
         * size of zero. The sanitised value is written back to the field.
         *
         * @param {{value: (string|number)}} el - The batch-size input element.
         */
        function checkInput(el) {
            const MIN_BATCH = 1;
            const MAX_BATCH = 32;

            let batch = Math.round(Number(el.value));
            if (!Number.isFinite(batch)) {
                batch = MIN_BATCH;
            }
            el.value = Math.min(MAX_BATCH, Math.max(MIN_BATCH, batch));
        }

        /**
         * Draws a random mini-batch from `x` and `y`.
         *
         * Each of the `batch_size` picks chooses an independent random index, so the
         * same example may appear more than once. The same index is used for `x` and
         * `y`, keeping inputs and targets paired.
         *
         * @param {object} x - Input tensor; `x.length` is used as the number of
         *     selectable entries and `x.data[p]` as entry p.
         * @param {object} y - Target tensor, indexed the same way through `y.data[p]`.
         * @param {number} batch_size - Number of examples to draw.
         * @returns {Array} `[x_batch, y_batch]`, each wrapped with `torch.tensor`.
         */
        function getBatch(x, y, batch_size) {
            const x_batch = [];
            const y_batch = [];
            for (let i = 0; i < batch_size; i++) {
                // Declared locally: the undeclared assignment used before leaked a global `p`.
                const p = Math.floor(Math.random() * x.length);
                x_batch.push(x.data[p]);
                y_batch.push(y.data[p]);
            }
            return [torch.tensor(x_batch), torch.tensor(y_batch)];
        }

        // True while a chain of trainLoop() timeouts is alive. Without it, every click on
        // "Train" started an additional chain and training steps ran several times per tick.
        let loop_scheduled = false;

        /**
         * Starts (or resumes) the training loop. Safe to call repeatedly: if a loop
         * is already running it is only told to keep going.
         */
        function trainLoopInitializer() {
            in_loop = true;
            if (loop_scheduled) {
                return;
            }
            loop_scheduled = true;
            trainLoop();
        }

        /**
         * Runs one training step and reschedules itself until `in_loop` turns false.
         * Yielding through setTimeout between steps lets the browser repaint and
         * handle clicks on Pause/Reset.
         */
        function trainLoop() {
            if (!in_loop) {
                loop_scheduled = false;
                return;
            }
            try {
                trainStep();
            } catch (err) {
                // The chain dies with the error; clear the flag so "Train" can start a new one.
                loop_scheduled = false;
                throw err;
            }
            // The previous 0.01 ms delay is not representable; 0 requests the shortest delay directly.
            setTimeout(trainLoop, 0);
        }
       
        /**
         * Runs a single training step and refreshes the on-screen counters and graph.
         *
         * On the first call after a reset it also builds the model from the current
         * `boxCount` layout, the loss function and the random dataset, and publishes
         * them on `globalThis` (`model`, `loss_func`, `x`, `y`).
         *
         * Hyperparameters are read from the form on every call so edits take effect
         * while training is running.
         */
        function trainStep() {
            if (!training) {
                training = true;

                // Recolour the +/- buttons: the layer layout cannot change while training.
                const layerButtons = document.getElementsByClassName('layer-button');
                for (const layerButton of layerButtons) {
                    layerButton.style.backgroundColor = '#0056b3';
                }

                // Snapshot of the layout, so the model keeps using the sizes it was built with.
                const hiddenDims = boxCount.slice();

                // Fully connected network: 32 -> hiddenDims... -> 32, with a ReLU after every layer.
                class NeuralNet extends torch.nn.Module {
                    constructor() {
                        super();
                        this.wIn = new torch.nn.Linear(32, hiddenDims[0]);
                        this.reluIn = new torch.nn.ReLU();
                        // Hidden layers keep the property names w1/relu1, w2/relu2, ...
                        for (let i = 1; i < hiddenDims.length; i++) {
                            this[`w${i}`] = new torch.nn.Linear(hiddenDims[i - 1], hiddenDims[i]);
                            this[`relu${i}`] = new torch.nn.ReLU();
                        }
                        this.wOut = new torch.nn.Linear(hiddenDims[hiddenDims.length - 1], 32);
                        this.reluOut = new torch.nn.ReLU();
                    }

                    forward(x) {
                        let z = this.reluIn.forward(this.wIn.forward(x));
                        for (let i = 1; i < hiddenDims.length; i++) {
                            z = this[`relu${i}`].forward(this[`w${i}`].forward(z));
                        }
                        return this.reluOut.forward(this.wOut.forward(z));
                    }
                }
                globalThis.model = new NeuralNet();

                // Define loss function:
                globalThis.loss_func = new torch.nn.CrossEntropyLoss();

                // Instantiate the random inputs and targets:
                globalThis.x = torch.randn([8, 32, 32]);
                globalThis.y = torch.randint(0, 32, [8, 32]);

                // A new model needs a new optimizer; force the rebuild below.
                globalThis.optimizer = null;
            }

            // Read the live hyperparameters from the form.
            // An empty or invalid batch size falls back to 1 so the batch is never empty.
            const batch_size = Math.max(1, Math.floor(Number(document.getElementById('batch-size').value)) || 1);
            const lr = Number(document.getElementById('learning-rate').value);
            const reg = Number(document.getElementById('regularization').value);
            const beta1 = Number(document.getElementById('beta1').value);
            const beta2 = Number(document.getElementById('beta2').value);
            const eps = Number(document.getElementById('epsilon').value);

            // Re-creating the optimizer on every step would discard whatever running state the
            // previous instance accumulated, so it is rebuilt only when a hyperparameter changes.
            // Arguments are positional (params, lr, reg, betas, eps): JavaScript has no keyword
            // arguments, and the former `betas=[...]` spelling created a stray global.
            const hyper_key = [lr, reg, beta1, beta2, eps].join('|');
            if (!globalThis.optimizer || globalThis.optimizer_key !== hyper_key) {
                globalThis.optimizer = new torch.optim.Adam(model.parameters(), lr, reg, [beta1, beta2], eps);
                globalThis.optimizer_key = hyper_key;
            }
            const optimizer = globalThis.optimizer;

            const [x_batch, y_batch] = getBatch(x, y, batch_size);

            // Forward pass and loss:
            const z = model.forward(x_batch);
            const loss = loss_func.forward(z, y_batch);

            // Backpropagate, update the weights, then clear the gradients for the next step:
            loss.backward();
            optimizer.step();
            optimizer.zero_grad();

            const loss_value = loss.data[0];
            if (isNaN(loss_value)) {
                // The loss blew up: plot a growing, jittered stand-in so the divergence is visible.
                overFlow = overFlow * 1.5;
                data.push(overFlow + (Math.random() - 0.5) * 15);
            } else {
                // Store the scalar itself rather than the wrapping array.
                data.push(loss_value);
            }

            // Count the step before displaying, so the counters include the work just done.
            iter += 1;
            total_visited += batch_size;

            document.getElementById('iter').innerHTML = `<b>Iteration:</b> ${iter}`;
            document.getElementById('total-visited').innerHTML = `<b>Total Training Examples:</b> ${total_visited}`;
            document.getElementById('loss').innerHTML = `<b>Loss:</b> ${loss_value.toFixed(3)}`;

            plotGraph();
        }

        /**
         * Asks the training loop to stop: trainLoop() checks `in_loop` before each
         * step and stops rescheduling itself once it is false. Model and counters
         * are left untouched.
         */
        function pauseTraining() {
            in_loop = false;
        }

        /**
         * Stops training and returns the demo to its initial state: counters, plotted
         * data and the NaN stand-in value are cleared, the graph and on-screen counters
         * are redrawn empty, and the +/- layer buttons get their default colour back.
         * The next training step builds a fresh model.
         */
        function resetTraining() {
            in_loop = false;
            training = false;
            iter = 0;
            total_visited = 0;
            // Otherwise a run that diverged would keep inflating the graph of the next run.
            overFlow = 1;
            data = [];
            plotGraph();

            // Clear the counters left on screen by the previous run.
            document.getElementById('iter').innerHTML = '<b>Iteration:</b>';
            document.getElementById('total-visited').innerHTML = '<b>Total Training Examples:</b>';
            document.getElementById('loss').innerHTML = '<b>Loss:</b>';

            const layerButtons = document.getElementsByClassName('layer-button');
            for (const layerButton of layerButtons) {
                layerButton.style.backgroundColor = '';
            }
        }

        /**
         * Redraws the loss curve stored in the global `data` array on the #graph canvas.
         *
         * The y axis spans 0..max(data), split into four labelled divisions whose tick
         * positions match their labels. The x axis is labelled every 100 iterations,
         * thinned out so that at most about ten labels are drawn. The amount of drawing
         * work no longer grows with the size of the loss values. Empty and single-point
         * data are drawn as bare axes.
         */
        function plotGraph() {
            const canvas = document.getElementById('graph');
            const ctx = canvas.getContext('2d');
            ctx.clearRect(0, 0, canvas.width, canvas.height);

            const margin = 50;
            const startX = margin;
            const startY = canvas.height - margin;
            const plotWidth = canvas.width - 2 * margin;
            const plotHeight = startY - margin;

            // Entries may be numbers or single-element arrays; normalise to numbers.
            const values = data.map(Number);
            const maxX = values.length - 1;

            // A plain loop instead of Math.max(...values): spreading a very long array
            // can exceed the engine's argument limit.
            let maxY = 0;
            for (const value of values) {
                if (Number.isFinite(value) && value > maxY) {
                    maxY = value;
                }
            }

            // Pixels per iteration / per loss unit; 0 when there is nothing to scale.
            const stepX = maxX > 0 ? plotWidth / maxX : 0;
            const stepY = maxY > 0 ? plotHeight / maxY : 0;
            const yDivisions = maxY > 0 ? 4 : 0;

            // Set default stroke style and line width
            ctx.strokeStyle = 'black';
            ctx.lineWidth = 1;

            // Draw x-axis label
            ctx.fillStyle = 'black';
            ctx.font = '12px Arial';
            ctx.textAlign = 'center';
            ctx.fillText('Iterations', canvas.width / 2, startY + 40);

            // Draw y-axis label (rotated a quarter turn)
            ctx.save();
            ctx.translate(8, canvas.height / 2);
            ctx.rotate(-Math.PI / 2);
            ctx.fillStyle = 'black';
            ctx.font = '12px Arial';
            ctx.textAlign = 'center';
            ctx.fillText('Cross-Entropy Loss', 0, 0);
            ctx.restore();

            // Draw horizontal grid lines, one per y division
            ctx.beginPath();
            for (let i = 1; i <= yDivisions; i++) {
                const gridY = startY - (i * plotHeight) / yDivisions;
                ctx.moveTo(startX, gridY);
                ctx.lineTo(canvas.width - startX, gridY);
            }
            ctx.stroke();

            // Draw x-axis
            ctx.beginPath();
            ctx.moveTo(startX, startY);
            ctx.lineTo(canvas.width - startX, startY);
            ctx.stroke();

            // Draw y-axis
            ctx.beginPath();
            ctx.moveTo(startX, startY);
            ctx.lineTo(startX, margin);
            ctx.stroke();

            // Draw ticks and labels on x-axis: every 100 iterations, widened for long runs.
            const tickEvery = 100 * Math.max(1, Math.ceil(maxX / 1000));
            for (let iteration = 0; iteration <= Math.max(maxX, 0); iteration += tickEvery) {
                const tickX = startX + iteration * stepX;
                ctx.beginPath();
                ctx.moveTo(tickX, startY);
                ctx.lineTo(tickX, startY + 5);
                ctx.stroke();
                ctx.fillText(iteration.toString(), tickX - 5, startY + 20);
            }

            // Draw ticks and labels on y-axis: label i sits exactly at value i * maxY / yDivisions.
            for (let i = 0; i <= yDivisions; i++) {
                const tickValue = yDivisions > 0 ? (i * maxY) / yDivisions : 0;
                const tickY = yDivisions > 0 ? startY - (i * plotHeight) / yDivisions : startY;
                ctx.beginPath();
                ctx.moveTo(startX, tickY);
                ctx.lineTo(startX - 5, tickY);
                ctx.stroke();
                ctx.fillText(tickValue.toFixed(2), startX - 30, tickY + 5);
            }

            // Draw line plot
            if (values.length > 0) {
                ctx.beginPath();
                ctx.strokeStyle = 'blue';
                ctx.lineWidth = 2;
                ctx.moveTo(startX, startY - values[0] * stepY);
                for (let i = 1; i <= maxX; i++) {
                    ctx.lineTo(startX + i * stepX, startY - values[i] * stepY);
                }
                ctx.stroke();
            }
        }

        // Initial setup: runs once when the script loads, while `training` is still false.
        addBox(); // Start with one hidden-layer box; removeBox() never goes below one.
    </script>
</body>

</html>
